import os
from tool import log_tool
from tool.config_base import BaseConfig


class Config(BaseConfig):
    """Training configuration for the XLNet model on the CLUE iflytek dataset.

    Filesystem paths prefer the server layout under ``/data/wxy`` and fall
    back to a local layout under ``E:/code/data`` when the server directory
    does not exist; every fallback is logged via ``log_tool.Logs.tmp``.
    These directory checks execute once, when the class body runs (i.e. at
    import time), not per instance.
    """

    # Identifiers of the model and dataset to use
    model_name = 'xlnet'
    data_name = 'data_iflytek'

    # Output directory for the trained model; can be overridden through the
    # `folder` environment variable
    folder = os.environ.get('folder', 'data/m5_xlnet_f1')

    # Model parameters
    sweet_add_layers = [5, 7, 9, 11]  # indices of the layers where incremental layers are added

    # Training parameters
    batch_size = 1
    lr = 1e-5
    max_epoch = 10
    sweet_threshold = 0.8413  # "sweet spot" threshold (NOTE(review): presumably Phi(1) of the standard normal — confirm)
    revise_threshold = 0.95  # review ("revise") threshold
    right_sampling_rate = 0.1  # sampling rate applied to correctly-predicted samples when building the learning set

    # --------------------------------------------------------------
    # Data parameters
    pretrain_model_dir = '/data/wxy/pretrain_model_file/xlnet/chinese_small_xlnet'
    if not os.path.isdir(pretrain_model_dir):
        pretrain_model_dir = 'E:/code/data/pretrain_model_file/xlnet/chinese_small_xlnet'
        log_tool.Logs.tmp.info(f'change pretrain_model_dir={pretrain_model_dir}')
    sp_path = os.path.join(pretrain_model_dir, 'spiece.model')  # presumably the SentencePiece tokenizer model

    vocab_size = 32000  # must match XLNet's n_token
    pad_id = 5  # NOTE(review): presumably the <pad> id in spiece.model — confirm
    max_len = 1000  # maximum number of characters
    data_dir = '/data/wxy/common/CLUE/iflytek_public'
    if not os.path.isdir(data_dir):
        data_dir = 'E:/code/data/common/CLUE/iflytek_public'
        # Log this fallback as well, matching the pretrain_model_dir fallback,
        # so a run reading data from an unexpected location is visible.
        log_tool.Logs.tmp.info(f'change data_dir={data_dir}')
